from __future__ import division, print_function, absolute_import

import os
import pickle

import numpy as np
import torch as t
from CDTools.tools.cmath import *
from RIP_tools import *

#
# This section runs the numerical experiment designed
# to test the achievable resolution of our algorithm
#

nsims = 1
ntries = 1000
# Use the GPU whenever torch can see one. Set this to False to force the
# reconstructions onto the CPU.
GPU = t.cuda.is_available()

# Create the output directory before starting. Otherwise a missing directory
# would only be discovered after all the reconstructions have already run.
output_dir = 'data'
if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

for probe_maxk in [100, 150, 200]:
    print('Studying Probe with Maximum k',probe_maxk)

    # Object resolutions to test: from 2*int(0.37*probe_maxk) up to (but not
    # including) 2*int(0.73*probe_maxk). All sizes are even.
    test_startk = int(probe_maxk * 0.37) * 2
    test_endk = int(probe_maxk * 0.73) * 2
    # The step is (span / 31) rounded down and doubled, so it stays even.
    # It is clamped to at least 2 because range() raises ValueError on a zero
    # step, which would happen for small probe_maxk.
    test_stepk = max(int((test_endk - test_startk) / 31) * 2, 2)

    field_size = int(np.ceil(test_endk + probe_maxk)) * 2
    # This sets the focal spot size of our BLR probe
    probe_fov_radius = (field_size * 4) // 10

    # Error is compared inside the central region of the image. A quarter of
    # the side length is trimmed from each edge, which leaves the central half
    # along each axis.
    error_check_padding = 1 / 4

    # This generates an ideal BLR probe with the given parameters
    nearfield = generate_blr_probe([field_size, field_size],
                                   probe_fov_radius,
                                   probe_maxk)

    # Final loss and error of every reconstruction, nested as
    # [resolution][attempt][initialization]
    losses = []
    errors = []
    resolutions = range(test_startk, test_endk, test_stepk)

    for resolution in resolutions:
        print('Studying Resolution',resolution)
        res_losses = []
        res_errors = []

        # The comparison region depends only on the resolution, so compute it
        # once here rather than once per reconstruction. A stop of None
        # (instead of -0) keeps the full extent when padding rounds to 0.
        padding = int(resolution * error_check_padding)
        stop = -padding if padding else None
        check_region = np.s_[padding:stop, padding:stop]

        # We give ourselves the opportunity to try multiple images at
        # each resolution, although nsims=1 here.
        for attempt in range(nsims):
            print('Trying Image',attempt+1,'of',nsims)
            attempt_losses = []
            attempt_errors = []

            # This generates a random complex image at our object resolution
            real = np.random.randn(resolution, resolution)
            imag = np.random.randn(resolution, resolution)
            original_image = real + 1j * imag

            original_image = complex_to_torch(original_image).to(dtype=t.float32)

            # All this does is upsample the probe in real space to avoid
            # aliasing in the probe-object interaction
            simulation_nearfield = expand_probe(nearfield, original_image)

            # This simulates the diffraction
            exit_wave = interact(original_image, simulation_nearfield)
            diffraction = propagate(exit_wave)
            intensities = measure(diffraction, nearfield.shape)

            # Normalize the pattern so its total counts are
            # 10 * probe_maxk**2 (intended as ~10 photons per pixel).
            # No Poisson noise is applied.
            intensities /= t.sum(intensities)
            intensities *= 10 * probe_maxk**2

            for initialization in range(ntries):
                # Each retrieval uses a different randomized initialization
                retrieval, loss = reconstruct(intensities, nearfield, resolution, 0.4, 1000, GPU=GPU, schedule=True)
                ret = torch_to_complex(retrieval)
                im = torch_to_complex(original_image)

                # Now we extract the error from the central region
                # defined by the padding
                error, _, _ = compare_images(ret, im, check_region)

                # And we save out the final image error and diffraction loss
                attempt_losses.append(loss[-1])
                attempt_errors.append(error)
                print('Initialization',initialization+1,
                      'With Error', '%.2e' %error,
                      'And Loss', '%.2e'%loss[-1])

            res_losses.append(attempt_losses)
            res_errors.append(attempt_errors)

        errors.append(res_errors)
        losses.append(res_losses)

    results = {'resolutions': resolutions, 'probe': nearfield,
               'losses': losses, 'errors': errors}
    filename = ('resolution trial ' + str(probe_maxk) + ' ' + str(nsims)
                + ' ' + str(ntries) + '.pickle')
    with open(os.path.join(output_dir, filename), 'wb') as f:
        pickle.dump(results, f)
